{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "“NNLM-Torch.ipynb”的副本",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "7-FP_Q8vJUtX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# code by Tae Hwan Jung @graykode, modify by wmathor\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optimizer\n",
        "import torch.utils.data as Data\n",
        "\n",
        "dtype = torch.FloatTensor"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5K5cfsDwP6JQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sentences = ['i like cat', 'i love coffee', 'i hate milk']\n",
        "sentences_list = \" \".join(sentences).split() # ['i', 'like', 'cat', 'i', 'love'. 'coffee',...]\n",
        "vocab = list(set(sentences_list))\n",
        "word2idx = {w:i for i, w in enumerate(vocab)}\n",
        "idx2word = {i:w for i, w in enumerate(vocab)}\n",
        "\n",
        "V = len(vocab)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dzkc00uqQiXT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def make_data(sentences):\n",
        "  input_data = []\n",
        "  target_data = []\n",
        "  for sen in sentences:\n",
        "    sen = sen.split() # ['i', 'like', 'cat']\n",
        "    input_tmp = [word2idx[w] for w in sen[:-1]]\n",
        "    target_tmp = word2idx[sen[-1]]\n",
        "\n",
        "    input_data.append(input_tmp)\n",
        "    target_data.append(target_tmp)\n",
        "  return input_data, target_data"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2kBFTErzRMFy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "input_data, target_data = make_data(sentences)\n",
        "input_data, target_data = torch.LongTensor(input_data), torch.LongTensor(target_data)\n",
        "dataset = Data.TensorDataset(input_data, target_data)\n",
        "loader = Data.DataLoader(dataset, 2, True)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XlQ46HoTRyxn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# parameters\n",
        "m = 2\n",
        "n_step = 2\n",
        "n_hidden = 10"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "uGURCnXxRav9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class NNLM(nn.Module):\n",
        "  def __init__(self):\n",
        "    super(NNLM, self).__init__()\n",
        "    self.C = nn.Embedding(V, m)\n",
        "    self.H = nn.Parameter(torch.randn(n_step * m, n_hidden).type(dtype))\n",
        "    self.d = nn.Parameter(torch.randn(n_hidden).type(dtype))\n",
        "    self.b = nn.Parameter(torch.randn(V).type(dtype))\n",
        "    self.W = nn.Parameter(torch.randn(n_step * m, V).type(dtype))\n",
        "    self.U = nn.Parameter(torch.randn(n_hidden, V).type(dtype))\n",
        "\n",
        "  def forward(self, X):\n",
        "    '''\n",
        "    X : [batch_size, n_step]\n",
        "    '''\n",
        "    X = self.C(X) # [batch_size, n_step, m]\n",
        "    X = X.view(-1, n_step * m) # [batch_szie, n_step * m]\n",
        "    hidden_out = torch.tanh(self.d + torch.mm(X, self.H)) # [batch_size, n_hidden]\n",
        "    output = self.b + torch.mm(X, self.W) + torch.mm(hidden_out, self.U)\n",
        "    return output\n",
        "model = NNLM()\n",
        "optim = optimizer.Adam(model.parameters(), lr=1e-3)\n",
        "criterion = nn.CrossEntropyLoss()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dtjKlEU-UP_l",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "759c1167-9a5c-4d1d-d621-2efb0e626454"
      },
      "source": [
        "for epoch in range(5000):\n",
        "  for batch_x, batch_y in loader:\n",
        "    pred = model(batch_x)\n",
        "    loss = criterion(pred, batch_y)\n",
        "\n",
        "    if (epoch + 1) % 1000 == 0:\n",
        "      print(epoch + 1, loss.item())\n",
        "    optim.zero_grad()\n",
        "    loss.backward()\n",
        "    optim.step()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "1000 0.03394317999482155\n",
            "2000 0.007886480540037155\n",
            "3000 0.0031574051827192307\n",
            "4000 0.0015269158175215125\n",
            "5000 0.0008065260481089354\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0Mqqhs4-UpaB",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "87c8a8c0-2413-4d78-8818-dc8f25e03682"
      },
      "source": [
        "# Pred\n",
        "pred = model(input_data).max(1, keepdim=True)[1]\n",
        "print([idx2word[idx.item()] for idx in pred.squeeze()])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['cat', 'coffee', 'milk']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "smWy_LLvU0sj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}